{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 下载c-eavl数据集\n",
    "\n",
    "```bash\n",
    "mkdir ceval-data\n",
    "cd ceval-data\n",
    "wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip \n",
    "unzip ceval-exam.zip -d ceval-exam\n",
    "wget https://github.com/hkust-nlp/ceval/blob/main/subject_mapping.json\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dev\n",
      "subject_mapping.json\n",
      "test\n",
      "val\n"
     ]
    }
   ],
   "source": [
    "! ls ceval-exam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, re\n",
    "import ujson\n",
    "import torch\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from transformers.generation.configuration_utils import GenerationConfig\n",
    "from transformers.generation.utils import LogitsProcessorList, InfNanRemoveLogitsProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ceval_dir = './ceval-exam'\n",
    "result_save_dir = './result'\n",
    "model_dir = '../model_save/dpo'  # 模型文件在上一层目录，使用dpo后的模型\n",
    "\n",
    "if not os.path.exists(result_save_dir):\n",
    "    os.mkdir(result_save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "subject_files = os.listdir(f\"{ceval_dir}/val\")\n",
    "subjects = [subjetc.replace('_val.csv', '') for subjetc in subject_files]\n",
    "\n",
    "subject_mapping = {}\n",
    "with open('./ceval-exam/subject_mapping.json', 'r', encoding='utf-8') as f:\n",
    "    subject_mapping = ujson.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于本项目的模型在sft阶段删除了很多带input的数据，且没有针对问题回答做微调，直接输入问题会解释问题中提到的关键词。所以c-eval测试使用预测 'A'、'B'、'C'、'D' token的方式。\n",
    "> 然而有时候，特别是零样本测试和面对没有做过指令微调的模型时，模型可能无法很好的理解指令，甚至有时不会回答问题。这种情况下我们推荐直接计算下一个预测token等于\"A\", \"B\", \"C\", \"D\"的概率，然后以概率最大的选项作为答案 \n",
    "> -- 这是一种受限解码生成的方法，MMLU的官方测试代码中是使用了这种方法进行测试。注意这种概率方法对思维链的测试不适用。\n",
    "\n",
    "见： [如何在C-Eval上测试](https://github.com/hkust-nlp/ceval/blob/main/README_zh.md#如何在C-Eval上测试)\n",
    "\n",
    "评测模式：zero-shot模式（chatbot/对话机器人模式）  \n",
    "dev数据集用来做few-shot，暂时不用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_prompt(df: pd.Series) -> str:\n",
    "    '''\n",
    "    将df中的 'question', 'A', 'B', 'C', 'D',格式化为问题\n",
    "    '''\n",
    "    prompt = f\"请回答单选题，回答字母A、B、C、D即可。问题：\\n{df['question']}\\n答案选项：\\n\"\n",
    "    for col in ['A', 'B', 'C', 'D']:\n",
    "        prompt += f\"{col}：{df[col]}\\n\"\n",
    "    \n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Accountant', '注册会计师', 'Other']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subject_mapping['accountant']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 52/52 [00:00<00:00, 617.74it/s]\n"
     ]
    }
   ],
   "source": [
    "do_test = False\n",
    "all_eval_items = []\n",
    "for i, subject_name in tqdm(enumerate(subjects), total=len(subjects)):\n",
    "    val_file = f\"{ceval_dir}/val/{subject_name}_val.csv\"\n",
    "    test_file = f\"{ceval_dir}/test/{subject_name}_test.csv\"\n",
    "\n",
    "    val_df = pd.read_csv(test_file) if do_test else pd.read_csv(val_file)\n",
    "    \n",
    "    for idx, row in val_df.iterrows():\n",
    "        quesuton = format_prompt(row)\n",
    "        answer = row['answer'] if 'answer' in val_df.columns else '' \n",
    "\n",
    "        item = {\n",
    "            'subject_en': subject_mapping[subject_name][0],\n",
    "            'subject_zh': subject_mapping[subject_name][1],\n",
    "            'category': subject_mapping[subject_name][2],  # 类别(STEM,Social Science,Humanities,Other四选一)\n",
    "            'question': quesuton,\n",
    "            'answer':answer,\n",
    "        }\n",
    "    \n",
    "        all_eval_items.append(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>subject_en</th>\n",
       "      <th>subject_zh</th>\n",
       "      <th>category</th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n下列关于税法基本原则的表述中，不正确的是...</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n甲公司是国内一家领先的新媒体、通信及移动...</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n根据我国《印花税暂行条例》的规定，下列各...</td>\n",
       "      <td>D</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n税务行政复议的申请人可以在得知税务机关作...</td>\n",
       "      <td>A</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n关于战略管理表述错误的是____。\\n答...</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   subject_en subject_zh category  \\\n",
       "0  Accountant      注册会计师    Other   \n",
       "1  Accountant      注册会计师    Other   \n",
       "2  Accountant      注册会计师    Other   \n",
       "3  Accountant      注册会计师    Other   \n",
       "4  Accountant      注册会计师    Other   \n",
       "\n",
       "                                            question answer  \n",
       "0  请回答单选题，回答字母A、B、C、D即可。问题：\\n下列关于税法基本原则的表述中，不正确的是...      D  \n",
       "1  请回答单选题，回答字母A、B、C、D即可。问题：\\n甲公司是国内一家领先的新媒体、通信及移动...      C  \n",
       "2  请回答单选题，回答字母A、B、C、D即可。问题：\\n根据我国《印花税暂行条例》的规定，下列各...      D  \n",
       "3  请回答单选题，回答字母A、B、C、D即可。问题：\\n税务行政复议的申请人可以在得知税务机关作...      A  \n",
       "4  请回答单选题，回答字母A、B、C、D即可。问题：\\n关于战略管理表述错误的是____。\\n答...      C  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_df = pd.DataFrame(all_eval_items)\n",
    "eval_df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[872, 873, 884, 886]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 加载模型\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)\n",
    "\n",
    "generation_config = GenerationConfig()\n",
    "generation_config.remove_invalid_values = True  # 自动添加InfNanRemoveLogitsProcessor\n",
    "generation_config.eos_token_id = tokenizer.eos_token_id\n",
    "generation_config.pad_token_id = tokenizer.pad_token_id\n",
    "# for t5, set decoder_start_token_id = pad_token_id\n",
    "generation_config.decoder_start_token_id = tokenizer.pad_token_id  \n",
    "generation_config.max_new_tokens = 16\n",
    "generation_config.num_beams = 1\n",
    "generation_config.do_sample = False   # greedy search\n",
    "\n",
    "choices = ['A', 'B', 'C', 'D']\n",
    "choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n",
    "choices_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1346/1346 [00:20<00:00, 64.11it/s]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 32\n",
    "batch_data, batch_answers = [], []\n",
    "n = len(eval_df)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "for idx, row in tqdm(eval_df.iterrows(), total=n):\n",
    "    batch_data.append(row['question'])\n",
    "    \n",
    "    if len(batch_data) == batch_size or idx == n - 1:\n",
    "        torch.cuda.empty_cache()\n",
    "        \n",
    "        encode_ids = tokenizer(batch_data, padding=True)\n",
    "        input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])\n",
    "        \n",
    "        outputs = model.generate(\n",
    "            input_ids=input_ids.to(device),\n",
    "            attention_mask=attention_mask.to(device),\n",
    "            generation_config=generation_config,\n",
    "            return_dict_in_generate=True,\n",
    "            output_scores=True,\n",
    "        )\n",
    "\n",
    "        scores = torch.stack(outputs['scores'], dim=1)\n",
    "        scores = torch.softmax(scores, dim=2)\n",
    "        scores = scores[...,  0, choices_ids]  #取第一个字符的ABCD概率\n",
    "        choices_index = torch.argmax(scores, dim=1)\n",
    "        \n",
    "        for i in choices_index:\n",
    "            batch_answers.append(choices[i])\n",
    "\n",
    "        batch_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df.insert(loc=5, column='model_predict', value=batch_answers)\n",
    "val_df = eval_df.copy(deep=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_df['is_correct'] = val_df['model_predict'] == val_df['answer']\n",
    "val_df['is_correct'] = val_df['is_correct'].astype(pd.Int16Dtype())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>subject_en</th>\n",
       "      <th>subject_zh</th>\n",
       "      <th>category</th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "      <th>model_predict</th>\n",
       "      <th>is_correct</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n下列关于税法基本原则的表述中，不正确的是...</td>\n",
       "      <td>D</td>\n",
       "      <td>A</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n甲公司是国内一家领先的新媒体、通信及移动...</td>\n",
       "      <td>C</td>\n",
       "      <td>A</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Accountant</td>\n",
       "      <td>注册会计师</td>\n",
       "      <td>Other</td>\n",
       "      <td>请回答单选题，回答字母A、B、C、D即可。问题：\\n根据我国《印花税暂行条例》的规定，下列各...</td>\n",
       "      <td>D</td>\n",
       "      <td>A</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   subject_en subject_zh category  \\\n",
       "0  Accountant      注册会计师    Other   \n",
       "1  Accountant      注册会计师    Other   \n",
       "2  Accountant      注册会计师    Other   \n",
       "\n",
       "                                            question answer model_predict  \\\n",
       "0  请回答单选题，回答字母A、B、C、D即可。问题：\\n下列关于税法基本原则的表述中，不正确的是...      D             A   \n",
       "1  请回答单选题，回答字母A、B、C、D即可。问题：\\n甲公司是国内一家领先的新媒体、通信及移动...      C             A   \n",
       "2  请回答单选题，回答字母A、B、C、D即可。问题：\\n根据我国《印花税暂行条例》的规定，下列各...      D             A   \n",
       "\n",
       "   is_correct  \n",
       "0           0  \n",
       "1           0  \n",
       "2           0  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_df.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>is_correct</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>category</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Humanities</th>\n",
       "      <td>63</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Other</th>\n",
       "      <td>89</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>STEM</th>\n",
       "      <td>89</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Social Science</th>\n",
       "      <td>72</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                is_correct\n",
       "category                  \n",
       "Humanities              63\n",
       "Other                   89\n",
       "STEM                    89\n",
       "Social Science          72"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_df =  val_df.groupby('category').sum('is_correct')\n",
    "final_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>is_correct</th>\n",
       "      <th>question_count</th>\n",
       "      <th>accuracy</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>category</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Humanities</th>\n",
       "      <td>63</td>\n",
       "      <td>257</td>\n",
       "      <td>24.51%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Other</th>\n",
       "      <td>89</td>\n",
       "      <td>384</td>\n",
       "      <td>23.18%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>STEM</th>\n",
       "      <td>89</td>\n",
       "      <td>430</td>\n",
       "      <td>20.70%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Social Science</th>\n",
       "      <td>72</td>\n",
       "      <td>275</td>\n",
       "      <td>26.18%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                is_correct  question_count accuracy\n",
       "category                                           \n",
       "Humanities              63             257   24.51%\n",
       "Other                   89             384   23.18%\n",
       "STEM                    89             430   20.70%\n",
       "Social Science          72             275   26.18%"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_df['question_count'] =  val_df.groupby('category').count()['question']\n",
    "final_df['accuracy'] = final_df['is_correct'] / final_df['question_count']\n",
    "final_df['accuracy']  = final_df['accuracy'] .apply(lambda x: format(x, '.2%'))\n",
    "final_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
